/*
 * Copyright 2020 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// Convert LLVM PIC ABI to emscripten ABI
//
// When generating -fPIC code, LLVM will generate imports called GOT.mem and
// GOT.func in order to access the addresses of external global data and
// functions.
//
// However emscripten uses a different ABI where function and data addresses
// are available at runtime via special `g$foo` and `fp$bar` function calls.
//
// Here we internalize all such wasm globals and generate code that sets their
// value based on the result of calling the `g$foo` and `fp$bar` functions at
// runtime.
//
// A function called `__assign_got_enties` is generated by this pass that
// performs all the assignments.
//

#include "abi/js.h"
#include "asm_v_wasm.h"
#include "ir/import-utils.h"
#include "ir/table-utils.h"
#include "pass.h"
#include "shared-constants.h"
#include "support/debug.h"

#define DEBUG_TYPE "emscripten-pic"

namespace wasm {

// Returns the global imported from ENV under `name`, creating a new import of
// type `type` (whose internal name is also `name`) if there is none yet.
// FIXME: building an ImportInfo scans every import, so each call is O(N).
static Global* ensureGlobalImport(Module* module, Name name, Type type) {
  Global* existing = ImportInfo(*module).getImportedGlobal(ENV, name);
  if (existing != nullptr) {
    return existing;
  }
  auto* global = new Global;
  global->module = ENV;
  global->base = name;
  global->name = name;
  global->type = type;
  module->addGlobal(global);
  return global;
}

// Returns the function imported from ENV under `name`, creating a new import
// with signature `sig` (whose internal name is also `name`) if there is none
// yet.
// FIXME: building an ImportInfo scans every import, so each call is O(N).
static Function*
ensureFunctionImport(Module* module, Name name, Signature sig) {
  Function* existing = ImportInfo(*module).getImportedFunction(ENV, name);
  if (existing != nullptr) {
    return existing;
  }
  auto* func = new Function;
  func->module = ENV;
  func->base = name;
  func->name = name;
  func->sig = sig;
  module->addFunction(func);
  return func;
}

// Converts LLVM PIC GOT imports into the emscripten ABI.
//
// visitGlobal() internalizes every imported global whose module is "GOT.func"
// or "GOT.mem" (initializing it to i32 0) and records it. visitModule() then
// adds a function named ASSIGN_GOT_ENTRIES whose body assigns each recorded
// global. The value comes either from an address computed locally (main
// module, symbol exported here) or from a `g$` / `fp$` getter imported from
// ENV.
struct EmscriptenPIC : public WalkerPass<PostWalker<EmscriptenPIC>> {

  // When `sideModule` is true, the locally-computed paths in visitModule()
  // are skipped and every entry is resolved through an imported getter.
  explicit EmscriptenPIC(bool sideModule) : sideModule(sideModule) {}

  // Records GOT.func / GOT.mem imports and turns them into internal globals.
  // Any other global is left untouched.
  void visitGlobal(Global* curr) {
    if (!curr->imported()) {
      return;
    }
    if (curr->module == "GOT.func") {
      gotFuncEntries.push_back(curr);
    } else if (curr->module == "GOT.mem") {
      gotMemEntries.push_back(curr);
    } else {
      return;
    }
    // Make this an internal, non-imported, global. Its real value is set by
    // the ASSIGN_GOT_ENTRIES function generated in visitModule().
    curr->module.clear();
    curr->init = Builder(*getModule()).makeConst(int32_t(0));
  }

  // Generates the ASSIGN_GOT_ENTRIES function that assigns every GOT global
  // recorded by visitGlobal(). Does nothing if no GOT entries were found.
  // Calls Fatal() on GOT.func entries that cannot be resolved to a function.
  void visitModule(Module* module) {
    BYN_TRACE("generateAssignGOTEntriesFunction\n");
    if (gotFuncEntries.empty() && gotMemEntries.empty()) {
      return;
    }

    Builder builder(*getModule());
    Function* assignFunc = builder.makeFunction(
      ASSIGN_GOT_ENTRIES, std::vector<NameType>{}, Type::none, {});
    Block* block = builder.makeBlock();
    assignFunc->body = block;

    bool hasSingleMemorySegment =
      module->memory.exists && module->memory.segments.size() == 1;

    for (Global* g : gotMemEntries) {
      // If this global is defined in this module, we export its address
      // relative to the relocatable memory. If we are in a main module, we can
      // just use that location (since if other modules have this symbol too, we
      // will "win" as we are loaded first). Otherwise, import a g$ getter. Note
      // that this depends on memory having a single segment, so we know the
      // offset, and that the export is a global.
      auto base = g->base;
      if (hasSingleMemorySegment && !sideModule) {
        if (auto* ex = module->getExportOrNull(base)) {
          if (ex->kind == ExternalKind::Global) {
            // The base relative to which we are computed is the offset of the
            // singleton segment.
            auto* relativeBase = ExpressionManipulator::copy(
              module->memory.segments[0].offset, *module);

            auto* offset = builder.makeGlobalGet(
              ex->value, module->getGlobal(ex->value)->type);
            auto* add = builder.makeBinary(AddInt32, relativeBase, offset);
            GlobalSet* globalSet = builder.makeGlobalSet(g->name, add);
            block->list.push_back(globalSet);
            continue;
          }
        }
      }
      Name getter(std::string("g$") + base.c_str());
      ensureFunctionImport(module, getter, Signature(Type::none, Type::i32));
      Expression* call = builder.makeCall(getter, {}, Type::i32);
      GlobalSet* globalSet = builder.makeGlobalSet(g->name, call);
      block->list.push_back(globalSet);
    }

    ImportInfo importInfo(*module);

    // We may have to add things to the table.
    Global* tableBase = nullptr;

    for (Global* g : gotFuncEntries) {
      // The function has to exist either as export or an import.
      // Note that we don't search for the function by name since its internal
      // name may be different.
      auto* ex = module->getExportOrNull(g->base);
      // If this is exported then it must be one of the functions implemented
      // here, and if this is a main module, then we can simply place the
      // function in the table: the loader will see it there and resolve all
      // other uses to this one.
      if (ex && !sideModule) {
        // Checked explicitly rather than with assert(): under NDEBUG an assert
        // vanishes and getFunction() would then look the export's value up in
        // the function namespace, which is wrong for a non-function export.
        if (ex->kind != ExternalKind::Function) {
          Fatal() << "GOT.func entry is exported as a non-function: "
                  << g->base;
        }
        auto* f = module->getFunction(ex->value);
        if (f->imported()) {
          Fatal() << "GOT.func entry is both imported and exported: "
                  << g->base;
        }
        // The base relative to which we are computed is the offset of the
        // singleton segment, which we must ensure exists
        if (!tableBase) {
          tableBase = ensureGlobalImport(module, TABLE_BASE, Type::i32);
        }
        module->table.exists = true;
        if (module->table.segments.empty()) {
          module->table.segments.resize(1);
          module->table.segments[0].offset =
            builder.makeGlobalGet(tableBase->name, Type::i32);
        }
        auto tableIndex =
          TableUtils::getOrAppend(module->table, f->name, *module);
        auto* c = LiteralUtils::makeFromInt32(tableIndex, Type::i32, *module);
        auto* getBase = builder.makeGlobalGet(tableBase->name, Type::i32);
        auto* add = builder.makeBinary(AddInt32, getBase, c);
        auto* globalSet = builder.makeGlobalSet(g->name, add);
        block->list.push_back(globalSet);
        continue;
      }
      // This is imported or in a side module. Create an fp$ import to get the
      // function table index from the dynamic loader.
      auto* f = importInfo.getImportedFunction(ENV, g->base);
      if (!f) {
        if (!ex) {
          Fatal() << "GOT.func entry with no import/export: " << g->base;
        }
        // Same reasoning as above: only a function export can supply the
        // signature needed for the fp$ getter name.
        if (ex->kind != ExternalKind::Function) {
          Fatal() << "GOT.func entry is exported as a non-function: "
                  << g->base;
        }
        f = module->getFunction(ex->value);
      }
      Name getter(
        (std::string("fp$") + g->base.c_str() + std::string("$") + getSig(f))
          .c_str());
      ensureFunctionImport(module, getter, Signature(Type::none, Type::i32));
      auto* call = builder.makeCall(getter, {}, Type::i32);
      auto* globalSet = builder.makeGlobalSet(g->name, call);
      block->list.push_back(globalSet);
    }

    module->addFunction(assignFunc);
  }

  // Globals imported from "GOT.func" / "GOT.mem", in the order visitGlobal()
  // saw them; they are internalized by the time visitModule() runs.
  std::vector<Global*> gotFuncEntries;
  std::vector<Global*> gotMemEntries;
  // Whether this pass targets a side module (see the constructor).
  bool sideModule;
};

// Creates the side-module variant of the pass: every GOT entry is resolved
// through an imported g$ / fp$ getter.
Pass* createEmscriptenPICPass() {
  const bool isSideModule = true;
  return new EmscriptenPIC(isSideModule);
}

// Creates the main-module variant of the pass: symbols exported by this
// module may be resolved locally instead of through imported getters.
Pass* createEmscriptenPICMainModulePass() {
  const bool isSideModule = false;
  return new EmscriptenPIC(isSideModule);
}

} // namespace wasm
